import pandas as pd
import numpy as np
import fiftyone as fo
import fiftyone.zoo as foz
import fiftyone.brain as fob
from fiftyone import ViewField as F

import plotly.express as px


def _size_bin_filters(bbox_area, boxes_areas):
    """Return one filter expression per bounding-box area bin.

    The bins are ``area <= a0``, ``a0 < area <= a1``, ..., ``area > a_last``,
    so the returned list has ``len(boxes_areas) + 1`` entries, ordered from
    the smallest to the largest boxes.
    """
    # First case: everything up to the smallest boundary.
    filters = [bbox_area <= boxes_areas[0]]
    # Cases in the middle: half-open intervals between consecutive boundaries.
    filters.extend(
        (bbox_area > lower) & (bbox_area <= upper)
        for lower, upper in zip(boxes_areas, boxes_areas[1:])
    )
    # Last case: everything above the largest boundary.
    filters.append(bbox_area > boxes_areas[-1])
    return filters


def _write_metric_figure(df_res, metric, y_title, title, tickvals, ticktext, out_path):
    """Plot column ``metric`` of ``df_res`` against ``box_size`` and save it to ``out_path``."""
    fig = px.line(df_res, x="box_size", y=metric, text=metric, title=title)
    fig.update_traces(textposition="bottom right")
    fig.update_xaxes(
        title="Bounding box size",
        range=[0, 210],
        tickvals=tickvals,
        ticktext=ticktext,
        tickangle=0,
    )
    fig.update_yaxes(title=y_title)
    fig.write_image(out_path)
    #fig.show()


def plot_f1_map_face_sizes(channel, output_dir="/results"):
    """Evaluate face detections per bounding-box size and plot mAP and F-score.

    Loads the FiftyOne dataset named ``channel``, splits ground-truth and
    predicted boxes into size bins (by pixel area), evaluates the predictions
    of ``yolo-resnetv1-fcg_average_vote`` against ``ground_truth`` for every
    bin, prints the numerical results, and writes two PDF figures
    (``<fig>-map_<NAME>.pdf`` and ``<fig>-f1_<NAME>.pdf``) to ``output_dir``.

    For the US channels (CNNW, FOXNEWSW, MSNBCW) the figures are prefixed
    ``fig5``, predictions are pre-filtered by IoU, and mAP is computed over
    the IoU thresholds 0.4-0.6; for every other dataset the prefix is
    ``fig6`` and the evaluation uses the library defaults.

    Note that figures 5-6 in the paper are post-processed to be the
    superposition of the channels.

    Args:
        channel: Name of the FiftyOne dataset to load.
        output_dir: Directory the PDF figures are written to. It is created
            if it does not exist. Defaults to ``/results``.
    """
    import os  # local import so this block does not depend on the file header

    pred_field = "yolo-resnetv1-fcg_average_vote"
    us_dataset_list = ["CNNW", "FOXNEWSW", "MSNBCW"]

    # Create the output location up front so that a bad path fails before the
    # (potentially long) evaluation instead of after it.
    os.makedirs(output_dir, exist_ok=True)

    # Specify dataset
    dataset_orig = fo.load_dataset(channel)
    is_us = dataset_orig.name in us_dataset_list
    fig_name = "fig5" if is_us else "fig6"

    # Do "evaluate_detections" to compute iou to be able to threshold wrt iou
    # for US data evaluation
    if is_us:
        dataset_orig.evaluate_detections(pred_field, "ground_truth", eval_key="eval", classwise=False)

    # Years 2000-2021 (the end of ``range`` is exclusive).
    # NOTE(review): the previous comment said "2000-2022"; confirm whether
    # 2022 is meant to be included.
    years_list = [str(year) for year in range(2000, 2022)]
    view_analysis = dataset_orig.match(F("year").is_in(years_list))

    # For US evaluation: keep only the detections above the IoU threshold.
    if is_us:
        view_analysis = view_analysis.filter_labels(pred_field, F("eval_iou") > 0.001).clone()

    # Bounding box area in pixels: relative width/height scaled by image size.
    bbox_area = (
        F("$metadata.width") * F("bounding_box")[2] *
        F("$metadata.height") * F("bounding_box")[3]
    )
    # Average bbox for NHK = 78x78, HODO = 52x52. US dataset around 135 x 135.
    # Smallest NHK = 3x3, HODO = 2x2. US = 35x35
    # Largest NHK = 258x258, HODO = 174x174. US = 390x390
    bin_sides = [8, 16, 32, 64, 96, 128, 156]
    boxes_areas = [side ** 2 for side in bin_sides]
    # 186 is for visualization purposes only: it is the x position used for
    # the open-ended last bin [156-].
    plot_sides = bin_sides + [186]

    # Generate views that contain only the boxes of each size bin
    views_list = [
        view_analysis
        .filter_labels("ground_truth", box_filter)
        .filter_labels(pred_field, box_filter)
        .filter_labels(pred_field, F("label") != "-1")
        for box_filter in _size_bin_filters(bbox_area, boxes_areas)
    ]

    # Run evaluation for the generated filtered views
    iou_threshs = [0.4, 0.45, 0.5, 0.55, 0.6] if is_us else None
    results_list = [
        view_filtered.evaluate_detections(
            pred_field,
            gt_field="ground_truth",
            eval_key="eval",
            compute_mAP=True,
            iou_threshs=iou_threshs,  # For US evaluation
        )
        for view_filtered in views_list
    ]

    # Print numerical results
    rows_df = []
    for res, side in zip(results_list, plot_sides):
        # Negative mAP values are clamped to 0 before converting to a percentage.
        res_map = round((max(res.mAP(), 0) * 100), 1)
        res_f1 = round(res.metrics()['fscore'], 3)
        rows_df.append([side ** 2, side, res_map, res_f1])
        print(f"mAP: {res_map}, F1: {res_f1}")

    df_res = pd.DataFrame(data=rows_df, columns=['area', 'box_size', 'map', 'f1'])
    print(df_res)

    # Name conversion (anything not listed is reported as NHK)
    out_name_ = {
        "hodost-lv": "HODO",
        "CNNW": "CNN",
        "FOXNEWSW": "FOX",
        "MSNBCW": "MSNBC",
    }.get(dataset_orig.name, "NHK")

    # Ticks: each label sits at the upper boundary of its bin. The first bin
    # (side <= 8) is plotted at x=8 but has no tick label of its own.
    tickvals = plot_sides[1:]
    ticktext = ['[8-16]', '[16-32]', '[32-64]', '[64-96]', '[96-128]', '[128-156]', '[156-]']

    # mAP
    _write_metric_figure(
        df_res, "map", "mAP",
        f"mAP per bounding box size for {out_name_}",
        tickvals, ticktext,
        os.path.join(output_dir, f"{fig_name}-map_{out_name_}.pdf"),
    )

    # F1 score
    _write_metric_figure(
        df_res, "f1", "F-score",
        f"F-score per bounding box size for {out_name_}",
        tickvals, ticktext,
        os.path.join(output_dir, f"{fig_name}-f1_{out_name_}.pdf"),
    )


def main():
    """Print the available FiftyOne datasets and generate the plots for every channel."""
    print(fo.list_datasets())

    #### Configuration options ####
    # Dataset name -> broadcaster:
    #   CNNW (CNN), FOXNEWSW (FOX), MSNBCW (MSNBC),
    #   news7-lv (NHK), hodost-lv (HODO Station)
    datasets_to_plot = [
        "CNNW",
        "FOXNEWSW",
        "MSNBCW",
        "news7-lv",
        "hodost-lv",
    ]

    for dataset_name in datasets_to_plot:
        plot_f1_map_face_sizes(dataset_name)


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

